Consider downloading this Jupyter Notebook and running it locally, or testing it with Colab.
In this tutorial, we will implement and discuss variants of modern CNN architectures. Many different architectures have been proposed over the past few years. Some of the most impactful ones, which are still relevant today, are the following: the GoogleNet/Inception architecture (winner of ILSVRC 2014), ResNet (winner of ILSVRC 2015), and DenseNet (best paper award at CVPR 2017). All of them were state-of-the-art models when they were proposed, and the core ideas of these networks are the foundations of most current state-of-the-art architectures. It is therefore important to understand these architectures in detail and learn how to implement them.
Let’s start with importing our standard libraries here.
```python
## Standard libraries
import os
import numpy as np
import random
from PIL import Image
import pytorch_lightning as pl
from types import SimpleNamespace

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf')  # For export
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()

## PyTorch
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim

# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

from tqdm import tqdm
```
We will use the same set_seed function as in the previous tutorials, as well as the path variables DATASET_PATH and CHECKPOINT_PATH. Adjust the paths if necessary.
```python
# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial"

# Function for setting the seed
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
set_seed(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
```
We also provide pretrained models and TensorBoard logs (more on this later) for this tutorial, which we download below.
```python
import urllib.request
from urllib.error import HTTPError

# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial5/"
# Files to download
pretrained_files = ["ResNet.ckpt", "ResNetPreAct.ckpt",
                    "tensorboards/ResNet/events.out.tfevents.resnet",
                    "tensorboards/ResNetPreAct/events.out.tfevents.resnetpreact"]
# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)
```
Throughout this tutorial, we will train and evaluate the models on the CIFAR10 dataset. This allows you to compare the results obtained here with the model you have implemented in the first assignment. As we have learned from the previous tutorial about initialization, it is important to have the data preprocessed with a zero mean. Therefore, as a first step, we will calculate the mean and standard deviation of the CIFAR dataset:
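A minimal sketch of this computation, using the raw uint8 data array exposed by torchvision's CIFAR10 (the resulting DATA_MEANS and DATA_STD are reused below for normalization):

```python
# Load the training set without transforms to access the raw uint8 images (shape: 50000x32x32x3)
train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
# Average over all images and pixels, keeping only the channel dimension
DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
DATA_STD = (train_dataset.data / 255.0).std(axis=(0, 1, 2))
print("Data mean", DATA_MEANS)
print("Data std", DATA_STD)
```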
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]
We will use this information to define a transforms.Normalize module which will normalize our data accordingly. Additionally, we will use data augmentation during training. This reduces the risk of overfitting and helps CNNs to generalize better. Specifically, we will apply two random augmentations.
First, we will flip each image horizontally with a probability of 50% (transforms.RandomHorizontalFlip). The object class usually does not change when flipping an image, and we don’t expect any image information to depend on the horizontal orientation. This would be different, however, if we tried to detect digits or letters in an image, as those have a specific orientation.
The second augmentation we use is called transforms.RandomResizedCrop. This transformation crops the image within a small range, possibly changing the aspect ratio, and scales it back to the original size afterward. Therefore, the actual pixel values change while the content and overall semantics of the image stay the same.
We will randomly split the training dataset into a training and a validation set. The validation set will be used for determining early stopping. After finishing the training, we test the models on the CIFAR test set.
```python
test_transform = transforms.Compose([transforms.ToTensor(),
                                     transforms.Normalize(DATA_MEANS, DATA_STD)
                                     ])
# For training, we add some augmentation. Networks are too powerful and would overfit.
train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.RandomResizedCrop((32,32), scale=(0.8,1.0), ratio=(0.9,1.1)),
                                      transforms.ToTensor(),
                                      transforms.Normalize(DATA_MEANS, DATA_STD)
                                      ])
# Loading the training dataset. We need to split it into a training and validation part
# We need to do a little trick because the validation set should not use the augmentation.
train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
set_seed(42)
train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
set_seed(42)
_, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

# Loading the test set
test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

# We define a set of data loaders that we can use for various purposes later.
train_loader = data.DataLoader(train_set, batch_size=128, shuffle=True, drop_last=True, pin_memory=True, num_workers=4)
val_loader = data.DataLoader(val_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
test_loader = data.DataLoader(test_set, batch_size=128, shuffle=False, drop_last=False, num_workers=4)
```
To verify that our normalization works, we can print out the mean and standard deviation of a single batch. The mean should be close to 0 and the standard deviation close to 1 for each channel:
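A quick check along these lines, taking one batch from the training loader and averaging over the batch and spatial dimensions, produces the statistics below:

```python
# Take one batch and compute per-channel statistics over batch, height, and width
imgs, _ = next(iter(train_loader))
print("Batch mean", imgs.mean(dim=[0, 2, 3]))
print("Batch std", imgs.std(dim=[0, 2, 3]))
```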
Batch mean tensor([0.0231, 0.0006, 0.0005])
Batch std tensor([0.9865, 0.9849, 0.9868])
Finally, let’s visualize a few images from the training set and see what they look like after random data augmentation:
```python
NUM_IMAGES = 4
images = [train_dataset[idx][0] for idx in range(NUM_IMAGES)]
orig_images = [Image.fromarray(train_dataset.data[idx]) for idx in range(NUM_IMAGES)]
orig_images = [test_transform(img) for img in orig_images]

img_grid = torchvision.utils.make_grid(torch.stack(images + orig_images, dim=0), nrow=4, normalize=True, pad_value=0.5)
img_grid = img_grid.permute(1, 2, 0)

plt.figure(figsize=(8,8))
plt.title("Augmentation examples on CIFAR10")
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()
```
PyTorch Lightning
In this notebook and in many following ones, we will make use of the library PyTorch Lightning. PyTorch Lightning is a framework that reduces the code needed to train, evaluate, and test a model in PyTorch. It also handles logging to TensorBoard, a visualization toolkit for ML experiments, and saves model checkpoints automatically with minimal code overhead on our side. This is extremely helpful for us as we want to focus on implementing different model architectures and spend little time on other code overhead. Note that at the time of writing/teaching, the framework was at version 1.8. Future versions might have a slightly changed interface and thus might not work perfectly with the code (we will try to keep it up to date as much as possible).
Now, we will take the first steps in PyTorch Lightning, and continue to explore the framework in our other tutorials. We have already imported the library above as pl.
PyTorch Lightning comes with a lot of useful functions, such as one for setting the seed:
```python
# Setting the seed
pl.seed_everything(42)
```
Seed set to 42
42
Thus, in the future, we don’t have to define our own set_seed function anymore.
In PyTorch Lightning, we define pl.LightningModule’s (inheriting from torch.nn.Module) that organize our code into 5 main sections:
Initialization (__init__), where we create all necessary parameters/models
Optimizers (configure_optimizers) where we create the optimizers, learning rate scheduler, etc.
Training loop (training_step) where we only have to define the loss calculation for a single batch (the loop of optimizer.zero_grad(), loss.backward() and optimizer.step(), as well as any logging/saving operation, is done in the background)
Validation loop (validation_step) where similarly to the training, we only have to define what should happen per step
Test loop (test_step) which is the same as validation, only on a test set.
Therefore, we don’t abstract away the PyTorch code, but rather organize it and define the default operations that are commonly used. If you need to change something else in your training/validation/test loop, there are many possible functions you can override (see the docs for details).
Now we can look at an example of what a Lightning module for training a CNN looks like:
```python
class CIFARModule(pl.LightningModule):

    def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
        """
        Inputs:
            model_name - Name of the model/CNN to run. Used for creating the model (see function below)
            model_hparams - Hyperparameters for the model, as dictionary.
            optimizer_name - Name of the optimizer to use. Currently supported: Adam, SGD
            optimizer_hparams - Hyperparameters for the optimizer, as dictionary. This includes learning rate, weight decay, etc.
        """
        super().__init__()
        # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
        self.save_hyperparameters()
        # Create model
        self.model = create_model(model_name, model_hparams)
        # Create loss module
        self.loss_module = nn.CrossEntropyLoss()
        # Example input for visualizing the graph in Tensorboard
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        # Forward function that is run when visualizing the graph
        return self.model(imgs)

    def configure_optimizers(self):
        # We will support Adam or SGD as optimizers.
        if self.hparams.optimizer_name == "Adam":
            # AdamW is Adam with a correct implementation of weight decay (see here for details: https://arxiv.org/pdf/1711.05101.pdf)
            optimizer = optim.AdamW(self.parameters(), **self.hparams.optimizer_hparams)
        elif self.hparams.optimizer_name == "SGD":
            optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
        else:
            assert False, f"Unknown optimizer: \"{self.hparams.optimizer_name}\""

        # We will reduce the learning rate by 0.1 after 100 and 150 epochs
        scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
        return [optimizer], [scheduler]

    def training_step(self, batch, batch_idx):
        # "batch" is the output of the training data loader.
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss_module(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()

        # Logs the accuracy per epoch to tensorboard (weighted average over batches)
        self.log('train_acc', acc, on_step=False, on_epoch=True)
        self.log('train_loss', loss)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs).argmax(dim=-1)
        acc = (labels == preds).float().mean()
        # By default logs it per epoch (weighted average over batches), and returns it afterwards
        self.log('test_acc', acc)
```
We see that the code is organized and clear, which helps if someone else tries to understand your code.
Another important part of PyTorch Lightning is the concept of callbacks. Callbacks are self-contained functions that contain the non-essential logic of your Lightning Module. They are usually called after finishing a training epoch, but can also influence other parts of your training loop. For instance, we will use the following two pre-defined callbacks: LearningRateMonitor and ModelCheckpoint. The learning rate monitor adds the current learning rate to our TensorBoard, which helps to verify that our learning rate scheduler works correctly. The model checkpoint callback allows you to customize the saving routine of your checkpoints. For instance, how many checkpoints to keep, when to save, which metric to look out for, etc. We import them below:
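Both callbacks live in the pytorch_lightning.callbacks module; the train_model function further below uses them directly:

```python
# Callbacks used in the Trainer below
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
```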
To allow running multiple different models with the same Lightning module, we define a function below that maps a model name to the model class. At this stage, the dictionary model_dict is empty, but we will fill it throughout the notebook with our new models.
```python
model_dict = {}

def create_model(model_name, model_hparams):
    if model_name in model_dict:
        return model_dict[model_name](**model_hparams)
    else:
        assert False, f"Unknown model name \"{model_name}\". Available models are: {str(model_dict.keys())}"
```
Similarly, to use the activation function as another hyperparameter in our model, we define a “name to function” dict below:
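A minimal sketch of such a dictionary, assuming the standard torch.nn activation classes (the exact set of entries here is an assumption; extend it with any activation you want to experiment with):

```python
# Mapping from activation name to its nn.Module class.
# Note: the ResNet below also passes this name to nn.init.kaiming_normal_,
# so it should be a nonlinearity that torch.nn.init knows (e.g. "relu", "tanh").
act_fn_by_name = {
    "tanh": nn.Tanh,
    "relu": nn.ReLU
}
```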
If we passed the classes or objects directly as arguments to the Lightning module, we couldn’t take advantage of PyTorch Lightning’s automatic hyperparameter saving and loading.
Besides the Lightning module, the second most important module in PyTorch Lightning is the Trainer. The trainer is responsible for executing the training steps defined in the Lightning module and completes the framework. Similar to the Lightning module, you can override any key part that you don’t want to be automated, but the default settings often reflect best practice. For a full overview, see the documentation. The most important functions we use below are:
trainer.fit: Takes as input a lightning module, a training dataset, and an (optional) validation dataset. This function trains the given module on the training dataset with occasional validation (default once per epoch, can be changed)
trainer.test: Takes as input a model and a dataset on which we want to test. It returns the test metric on the dataset.
For training and testing, we don’t have to worry about things like setting the model to eval mode (model.eval()) as this is all done automatically. Below, we define a training function for our models:
```python
def train_model(model_name, save_name=None, **kwargs):
    """
    Inputs:
        model_name - Name of the model you want to run. Is used to look up the class in "model_dict"
        save_name (optional) - If specified, this name will be used for creating the checkpoint and logging directory.
    """
    if save_name is None:
        save_name = model_name

    # Create a PyTorch Lightning trainer with the generation callback
    trainer = pl.Trainer(default_root_dir=os.path.join(CHECKPOINT_PATH, save_name),       # Where to save models
                         accelerator="gpu" if str(device).startswith("cuda") else "cpu",  # We run on a GPU (if possible)
                         devices=1,                                                       # How many GPUs/CPUs we want to use (1 is enough for the notebooks)
                         max_epochs=180,                                                  # How many epochs to train for if no patience is set
                         callbacks=[ModelCheckpoint(save_weights_only=True, mode="max", monitor="val_acc"),  # Save the best checkpoint based on the maximum val_acc recorded. Saves only weights and not optimizer
                                    LearningRateMonitor("epoch")],                        # Log learning rate every epoch
                         enable_progress_bar=True)                                        # Set to False if you do not want a progress bar
    trainer.logger._log_graph = True          # If True, we plot the computation graph in tensorboard
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    # Check whether pretrained model exists. If yes, load it and skip training
    pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
    if os.path.isfile(pretrained_filename):
        print(f"Found pretrained model at {pretrained_filename}, loading...")
        model = CIFARModule.load_from_checkpoint(pretrained_filename)  # Automatically loads the model with the saved hyperparameters
    else:
        pl.seed_everything(42)  # To be reproducible
        model = CIFARModule(model_name=model_name, **kwargs)
        trainer.fit(model, train_loader, val_loader)
        model = CIFARModule.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)  # Load best checkpoint after training

    # Test best model on validation and test set
    val_result = trainer.test(model, val_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]["test_acc"], "val": val_result[0]["test_acc"]}

    return model, result
```
Finally, we can focus on the convolutional neural network we want to implement today: ResNet.
ResNet
The ResNet paper is one of the most cited AI papers, and has been the foundation for neural networks with more than 1,000 layers. Despite its simplicity, the idea of residual connections is highly effective as it supports stable gradient propagation through the network. Instead of modeling \(x_{l+1}=F(x_{l})\), we model \(x_{l+1}=x_{l}+F(x_{l})\) where \(F\) is a non-linear mapping (usually a sequence of NN modules like convolutions, activation functions, and normalizations). If we do backpropagation on such residual connections, we obtain:

\[\frac{\partial x_{l+1}}{\partial x_{l}} = \mathbf{I} + \frac{\partial F(x_{l})}{\partial x_{l}}\]
The bias towards the identity matrix guarantees a stable gradient propagation that is less affected by \(F\) itself. Many variants of ResNet have been proposed, which mostly concern the function \(F\), or the operations applied to the sum. In this tutorial, we look at two of them: the original ResNet block, and the Pre-Activation ResNet block. We visually compare the blocks below (figure credit - He et al.):
The original ResNet block applies a non-linear activation function, usually ReLU, after the skip connection. In contrast, the pre-activation ResNet block applies the non-linearity at the beginning of \(F\). Both have their advantages and disadvantages. For very deep networks, however, the pre-activation ResNet has been shown to perform better, as the gradient flow is guaranteed to contain the identity term calculated above and is not harmed by any non-linear activation applied on top of it. For comparison, in this notebook, we implement both ResNet types as shallow networks.
Let’s start with the original ResNet block. The visualization above already shows what layers are included in \(F\). One special case we have to handle is when we want to reduce the image dimensions in terms of width and height. The basic ResNet block requires \(F(x_{l})\) to be of the same shape as \(x_{l}\). Thus, we need to change the dimensionality of \(x_{l}\) as well before adding it to \(F(x_{l})\). The original implementation used an identity mapping with stride 2 and padded additional feature dimensions with 0. However, the more common implementation is to use a 1x1 convolution with stride 2, as it allows us to change the feature dimensionality while being efficient in parameter and computation cost. The code for the ResNet block is relatively simple and shown below:
```python
class ResNetBlock(nn.Module):

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, we want to apply a stride inside the block and reduce the output shape by 2 in height and width
            c_out - Number of output features. Note that this is only relevant if subsample is True, as otherwise, c_out = c_in
        """
        super().__init__()
        if not subsample:
            c_out = c_in

        # Network representing F
        self.net = nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1 if not subsample else 2, bias=False),  # No bias needed as the Batch Norm handles it
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c_out)
        )

        # 1x1 convolution with stride 2 means we take the upper left value, and transform it to new output size
        self.downsample = nn.Conv2d(c_in, c_out, kernel_size=1, stride=2) if subsample else None
        self.act_fn = act_fn()

    def forward(self, x):
        z = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        out = z + x
        out = self.act_fn(out)
        return out
```
The second block we implement is the pre-activation ResNet block. For this, we have to change the order of layers in self.net, and we do not apply an activation function on the output. Additionally, the downsampling operation has to apply a non-linearity as well, since the input, \(x_l\), has not been processed by a non-linearity yet. Hence, the block looks as follows:
```python
class PreActResNetBlock(nn.Module):

    def __init__(self, c_in, act_fn, subsample=False, c_out=-1):
        """
        Inputs:
            c_in - Number of input features
            act_fn - Activation class constructor (e.g. nn.ReLU)
            subsample - If True, we want to apply a stride inside the block and reduce the output shape by 2 in height and width
            c_out - Number of output features. Note that this is only relevant if subsample is True, as otherwise, c_out = c_in
        """
        super().__init__()
        if not subsample:
            c_out = c_in

        # Network representing F
        self.net = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=1 if not subsample else 2, bias=False),
            nn.BatchNorm2d(c_out),
            act_fn(),
            nn.Conv2d(c_out, c_out, kernel_size=3, padding=1, bias=False)
        )

        # 1x1 convolution can apply non-linearity as well, but is not strictly necessary
        self.downsample = nn.Sequential(
            nn.BatchNorm2d(c_in),
            act_fn(),
            nn.Conv2d(c_in, c_out, kernel_size=1, stride=2, bias=False)
        ) if subsample else None

    def forward(self, x):
        z = self.net(x)
        if self.downsample is not None:
            x = self.downsample(x)
        out = z + x
        return out
```
Similarly to the model selection, we define a dictionary to create a mapping from string to block class, as shown below. We will use the string name as a hyperparameter value in our model to choose between the ResNet blocks. Feel free to implement any other ResNet block type and add it here as well.
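With the two block classes defined above, the dictionary is simply:

```python
resnet_blocks_by_name = {
    "ResNetBlock": ResNetBlock,
    "PreActResNetBlock": PreActResNetBlock
}
```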
The overall ResNet architecture consists of stacking multiple ResNet blocks, some of which downsample the input. When talking about the ResNet blocks in the whole network, we usually group them by output shape. Hence, if we say the ResNet has [3,3,3] blocks, it means that we have three groups of 3 ResNet blocks each, where subsampling takes place in the fourth and seventh blocks. The ResNet with [3,3,3] blocks on CIFAR10 is visualized below.
The three groups operate on the resolutions \(32\times32\), \(16\times16\) and \(8\times8\) respectively. The blocks in orange denote ResNet blocks with downsampling. The same notation is used by many other implementations such as in the torchvision library from PyTorch. Thus, our code looks as follows:
```python
class ResNet(nn.Module):

    def __init__(self, num_classes=10, num_blocks=[3,3,3], c_hidden=[16,32,64], act_fn_name="relu", block_name="ResNetBlock", **kwargs):
        """
        Inputs:
            num_classes - Number of classification outputs (10 for CIFAR10)
            num_blocks - List with the number of ResNet blocks to use. The first block of each group uses downsampling, except the first.
            c_hidden - List with the hidden dimensionalities in the different blocks. Usually multiplied by 2 the deeper we go.
            act_fn_name - Name of the activation function to use, looked up in "act_fn_by_name"
            block_name - Name of the ResNet block, looked up in "resnet_blocks_by_name"
        """
        super().__init__()
        assert block_name in resnet_blocks_by_name
        self.hparams = SimpleNamespace(num_classes=num_classes,
                                       c_hidden=c_hidden,
                                       num_blocks=num_blocks,
                                       act_fn_name=act_fn_name,
                                       act_fn=act_fn_by_name[act_fn_name],
                                       block_class=resnet_blocks_by_name[block_name])
        self._create_network()
        self._init_params()

    def _create_network(self):
        c_hidden = self.hparams.c_hidden

        # A first convolution on the original image to scale up the channel size
        if self.hparams.block_class == PreActResNetBlock:  # => Don't apply non-linearity on output
            self.input_net = nn.Sequential(
                nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False)
            )
        else:
            self.input_net = nn.Sequential(
                nn.Conv2d(3, c_hidden[0], kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_hidden[0]),
                self.hparams.act_fn()
            )

        # Creating the ResNet blocks
        blocks = []
        for block_idx, block_count in enumerate(self.hparams.num_blocks):
            for bc in range(block_count):
                subsample = (bc == 0 and block_idx > 0)  # Subsample the first block of each group, except the very first one.
                blocks.append(
                    self.hparams.block_class(c_in=c_hidden[block_idx if not subsample else (block_idx-1)],
                                             act_fn=self.hparams.act_fn,
                                             subsample=subsample,
                                             c_out=c_hidden[block_idx])
                )
        self.blocks = nn.Sequential(*blocks)

        # Mapping to classification output
        self.output_net = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,1)),
            nn.Flatten(),
            nn.Linear(c_hidden[-1], self.hparams.num_classes)
        )

    def _init_params(self):
        # Based on our discussion in Tutorial 4, we should initialize the convolutions according to the activation function
        # Fan-out focuses on the gradient distribution, and is commonly used in ResNets
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity=self.hparams.act_fn_name)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.input_net(x)
        x = self.blocks(x)
        x = self.output_net(x)
        return x
```
We also need to add the new ResNet class to our model dictionary:
```python
model_dict["ResNet"] = ResNet
```
Finally, we can train our ResNet models. One difference from the GoogleNet training is that we explicitly use SGD with momentum as the optimizer instead of Adam. Adam often leads to slightly worse accuracy on plain, shallow ResNets. It is not 100% clear why Adam performs worse in this context, but one possible explanation is related to ResNet’s loss surface. ResNet has been shown to produce smoother loss surfaces than networks without skip connections (see Li et al., 2018 for details). A possible visualization of the loss surface with and without skip connections is shown below (figure credit - Li et al.):
The \(x\) and \(y\) axes show a projection of the parameter space, and the \(z\) axis shows the loss values achieved by different parameter values. On smooth surfaces like the one on the right, we might not require the adaptive learning rate that Adam provides. Instead, Adam can get stuck in local optima, while SGD finds the wider minima that tend to generalize better. Answering this question in detail would require an extra tutorial, however. For now, we conclude: for ResNet architectures, consider the optimizer an important hyperparameter, and try training with both Adam and SGD. Let’s train the model below with SGD:
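A call along the following lines starts the training, or, given the downloaded checkpoint, loads the pretrained model. The exact optimizer hyperparameter values shown here are plausible defaults rather than tuned settings:

```python
# Train (or load) a plain ResNet with SGD + momentum
resnet_model, resnet_results = train_model(model_name="ResNet",
                                           model_hparams={"num_classes": 10,
                                                          "c_hidden": [16, 32, 64],
                                                          "num_blocks": [3, 3, 3],
                                                          "act_fn_name": "relu"},
                                           optimizer_name="SGD",
                                           optimizer_hparams={"lr": 0.1,          # Assumed values; tune as needed
                                                              "momentum": 0.9,
                                                              "weight_decay": 1e-4})
```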
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Found pretrained model at ../saved_models/tutorial/ResNet.ckpt, loading...
Let’s also train the pre-activation ResNet for comparison:
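The call mirrors the one above, selecting the pre-activation block and a separate checkpoint name (again, the hyperparameter values are assumptions):

```python
# Train (or load) the pre-activation variant under its own checkpoint name
resnetpreact_model, resnetpreact_results = train_model(model_name="ResNet",
                                                       model_hparams={"num_classes": 10,
                                                                      "c_hidden": [16, 32, 64],
                                                                      "num_blocks": [3, 3, 3],
                                                                      "act_fn_name": "relu",
                                                                      "block_name": "PreActResNetBlock"},
                                                       optimizer_name="SGD",
                                                       optimizer_hparams={"lr": 0.1,
                                                                          "momentum": 0.9,
                                                                          "weight_decay": 1e-4},
                                                       save_name="ResNetPreAct")
```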
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Found pretrained model at ../saved_models/tutorial/ResNetPreAct.ckpt, loading...
Tensorboard log
A nice extra of PyTorch Lightning is the automatic logging to TensorBoard. To give you a better intuition of what TensorBoard can be used for, let’s look at the board that PyTorch Lightning generated when training the ResNet. TensorBoard provides inline functionality for Jupyter notebooks, and we use it here:
```python
# Opens tensorboard in notebook. Adjust the path to your CHECKPOINT_PATH! Feel free to change "ResNet" to "ResNetPreAct"
%load_ext tensorboard
%tensorboard --logdir ../saved_models/tutorial/tensorboards/ResNet/
```
The tensorboard extension is already loaded. To reload it, use:
%reload_ext tensorboard
Reusing TensorBoard on port 6006 (pid 51871), started 0:19:38 ago. (Use '!kill 51871' to kill it.)